#
# Copyright (c) 2019 UAVCAN Development Team
# This software is distributed under the terms of the MIT License.
# Author: Pavel Kirienko <pavel.kirienko@zubax.com>
#

from __future__ import annotations
import copy
import typing
import asyncio
import logging
import dataclasses
import pyuavcan
from pyuavcan.transport.commons.high_overhead_transport import TransferReassembler
from .._frame import SerialFrame
from ._base import SerialSession


_logger = logging.getLogger(__name__)


@dataclasses.dataclass
class SerialInputSessionStatistics(pyuavcan.transport.SessionStatistics):
    """
    Statistics of a serial input session.

    Extends the generic session counters of the base class with two error breakdowns:
    one keyed by mismatched data type hash, one keyed by source node-ID.
    """

    mismatched_data_type_hashes: typing.Dict[int, int] = dataclasses.field(
        default_factory=dict,
    )
    """
    Maps every data type hash seen in a received frame that disagreed with the local type configuration
    to the number of frames that carried it.
    """

    reassembly_errors_per_source_node_id: typing.Dict[int, typing.Dict[TransferReassembler.Error, int]] = (
        dataclasses.field(default_factory=dict)
    )
    """
    Maps a source node-ID to a dict counting how many times each reassembly error enum member occurred.
    """


class SerialInputSession(SerialSession, pyuavcan.transport.InputSession):
    """
    Input session of the serial transport.

    The transport feeds received frames in through :meth:`_process_frame`. Completed transfers are put into
    an internal queue, and the application reads them from there via :meth:`receive_until`.
    """

    DEFAULT_TRANSFER_ID_TIMEOUT = 2.0
    """
    Units are seconds. Can be overridden after instantiation if needed.
    """

    def __init__(self,
                 specifier:        pyuavcan.transport.InputSessionSpecifier,
                 payload_metadata: pyuavcan.transport.PayloadMetadata,
                 loop:             asyncio.AbstractEventLoop,
                 finalizer:        typing.Callable[[], None]):
        """
        Do not call this directly.
        Instead, use the factory method :meth:`pyuavcan.transport.serial.SerialTransport.get_input_session`.

        :param specifier:        Data specifier and (optional) remote node-ID this session receives from.
        :param payload_metadata: Expected data type hash and the maximum payload size used for reassembly.
        :param loop:             Event loop whose clock defines the deadlines accepted by :meth:`receive_until`.
        :param finalizer:        Handed over to the base class; it is invoked when the session is closed.
        :raises TypeError:       If ``specifier`` or ``payload_metadata`` is of the wrong type.
        """
        self._specifier = specifier
        self._payload_metadata = payload_metadata
        self._loop = loop
        assert self._loop is not None

        if not isinstance(self._specifier, pyuavcan.transport.InputSessionSpecifier) or \
                not isinstance(self._payload_metadata, pyuavcan.transport.PayloadMetadata):  # pragma: no cover
            raise TypeError('Invalid parameters')

        self._statistics = SerialInputSessionStatistics()
        self._transfer_id_timeout = self.DEFAULT_TRANSFER_ID_TIMEOUT
        self._queue: asyncio.Queue[pyuavcan.transport.TransferFrom] = asyncio.Queue()
        # One reassembler per source node-ID, created lazily on the first frame from that node.
        self._reassemblers: typing.Dict[int, TransferReassembler] = {}

        super().__init__(finalizer)

    def _process_frame(self, frame: SerialFrame) -> None:
        """
        Feeds one received frame into the session.

        A completed transfer is enqueued for :meth:`receive_until`. Frames that cannot be used only update
        the statistics.

        This is a part of the transport-internal API. It's a public method despite the name because Python's
        visibility handling capabilities are limited (there is no protected visibility).
        """
        assert frame.data_specifier == self._specifier.data_specifier, 'Internal protocol violation'
        self._statistics.frames += 1

        if frame.data_type_hash != self._payload_metadata.data_type_hash:
            # The hash does not match the local payload metadata. Record which hash it was, then discard the frame.
            self._statistics.errors += 1
            mismatched = self._statistics.mismatched_data_type_hashes
            mismatched[frame.data_type_hash] = mismatched.get(frame.data_type_hash, 0) + 1
            return

        transfer: typing.Optional[pyuavcan.transport.TransferFrom]
        if frame.source_node_id is None:
            # Anonymous frames have no source node-ID to key a reassembler by.
            # construct_anonymous_transfer() returns None for frames it cannot turn into a transfer.
            transfer = TransferReassembler.construct_anonymous_transfer(frame)
            if transfer is None:
                self._statistics.errors += 1
                _logger.debug('%s: Invalid anonymous frame: %s', self, frame)
        else:
            transfer = self._get_reassembler(frame.source_node_id).process_frame(frame, self._transfer_id_timeout)

        if transfer is not None:
            self._statistics.transfers += 1
            self._statistics.payload_bytes += sum(map(len, transfer.fragmented_payload))
            _logger.debug('%s: Received transfer: %s; current stats: %s', self, transfer, self._statistics)
            try:
                self._queue.put_nowait(transfer)
            except asyncio.QueueFull:  # pragma: no cover
                # TODO: make the queue capacity configurable
                self._statistics.drops += len(transfer.fragmented_payload)

    async def receive_until(self, monotonic_deadline: float) -> typing.Optional[pyuavcan.transport.TransferFrom]:
        """
        Returns the next received transfer, waiting no longer than until ``monotonic_deadline``.

        The deadline is expressed in the event loop's clock (``loop.time()``).
        If the deadline has already passed, only an already-queued transfer is returned, without waiting.
        Returns None if no transfer is available in time.
        Queued transfers remain readable after the session is closed. Once the queue is empty,
        ``_raise_if_closed()`` is consulted (presumably raising if the session has been closed).
        """
        try:
            timeout = monotonic_deadline - self._loop.time()
            if timeout > 0:
                # No explicit loop argument: it was deprecated in Python 3.8 and removed in 3.10.
                # Inside a coroutine the running loop is used implicitly.
                transfer = await asyncio.wait_for(self._queue.get(), timeout)
            else:
                transfer = self._queue.get_nowait()
        except (asyncio.TimeoutError, asyncio.QueueEmpty):
            # If there are unprocessed transfers, allow the caller to read them even if the instance is closed.
            self._raise_if_closed()
            return None
        else:
            assert isinstance(transfer, pyuavcan.transport.TransferFrom), 'Internal protocol violation'
            assert transfer.source_node_id == self._specifier.remote_node_id or self._specifier.remote_node_id is None
            return transfer

    @property
    def transfer_id_timeout(self) -> float:
        """
        Transfer-ID timeout in seconds, passed to the per-node reassemblers with every frame.
        The setter accepts only positive values and raises :class:`ValueError` otherwise.
        """
        return self._transfer_id_timeout

    @transfer_id_timeout.setter
    def transfer_id_timeout(self, value: float) -> None:
        if value > 0:
            self._transfer_id_timeout = float(value)
        else:
            raise ValueError(f'Invalid value for transfer-ID timeout [second]: {value}')

    @property
    def specifier(self) -> pyuavcan.transport.InputSessionSpecifier:
        return self._specifier

    @property
    def payload_metadata(self) -> pyuavcan.transport.PayloadMetadata:
        return self._payload_metadata

    def sample_statistics(self) -> SerialInputSessionStatistics:
        """
        Returns an independent snapshot of the session statistics.
        """
        # A deep copy is required: the statistics contain (nested) dicts that keep being mutated as frames
        # arrive. A shallow copy would share them with the caller, so a previously taken snapshot would change.
        return copy.deepcopy(self._statistics)

    def _get_reassembler(self, source_node_id: int) -> TransferReassembler:
        """
        Returns the reassembler for the given source node-ID, creating it on first use.

        Creating it also creates the node's (initially empty) entry in
        ``reassembly_errors_per_source_node_id``, which the error callback then updates.
        """
        try:
            return self._reassemblers[source_node_id]
        except LookupError:
            pass

        def on_reassembly_error(error: TransferReassembler.Error) -> None:
            self._statistics.errors += 1
            # Looked up on every call (not captured) so the callback always updates the live statistics object.
            counts = self._statistics.reassembly_errors_per_source_node_id.setdefault(source_node_id, {})
            counts[error] = counts.get(error, 0) + 1

        self._statistics.reassembly_errors_per_source_node_id.setdefault(source_node_id, {})
        reasm = TransferReassembler(source_node_id=source_node_id,
                                    max_payload_size_bytes=self._payload_metadata.max_size_bytes,
                                    on_error_callback=on_reassembly_error)
        self._reassemblers[source_node_id] = reasm
        _logger.debug('%s: New %s (%d total)', self, reasm, len(self._reassemblers))
        return reasm


# noinspection PyProtectedMember
def _unittest_input_session() -> None:
    """
    Unit test for :class:`SerialInputSession`.

    Feeds hand-crafted frames directly into ``_process_frame()`` and checks the resulting statistics and
    received transfers. The statistics are cumulative, so each expected value builds on all previous steps.

    Covers:
    - construction and transfer-ID timeout validation;
    - receive timeouts on an empty queue;
    - anonymous transfers;
    - data type hash mismatches;
    - out-of-order multi-frame reassembly;
    - reassembly error accounting;
    - idempotent closing.
    """
    import asyncio
    from pytest import raises, approx
    from pyuavcan.transport import InputSessionSpecifier, MessageDataSpecifier, Priority, TransferFrom
    from pyuavcan.transport import PayloadMetadata, Timestamp
    from pyuavcan.transport.commons.high_overhead_transport import TransferCRC

    # All frames share one timestamp, priority and destination node-ID.
    # This lets the received transfers be compared against fixed expected values.
    ts = Timestamp.now()
    prio = Priority.SLOW
    dst_nid = 1234

    run_until_complete = asyncio.get_event_loop().run_until_complete
    get_monotonic = asyncio.get_event_loop().time

    nihil_supernum = b'nihil supernum'

    # Set by the finalizer so the test can verify that close() invokes it.
    finalized = False

    def do_finalize() -> None:
        nonlocal finalized
        finalized = True

    # The second argument is presumably the remote node-ID. None lets the session accept transfers from any
    # source; frames from several node-IDs (and anonymous ones) are fed below.
    session_spec = InputSessionSpecifier(MessageDataSpecifier(12345), None)
    # Data type hash and (presumably) the maximum payload size in bytes.
    payload_meta = PayloadMetadata(0xdead_beef_bad_c0ffe, 100)

    sis = SerialInputSession(specifier=session_spec,
                             payload_metadata=payload_meta,
                             loop=asyncio.get_event_loop(),
                             finalizer=do_finalize)
    assert sis.specifier == session_spec
    assert sis.payload_metadata == payload_meta
    assert sis.sample_statistics() == SerialInputSessionStatistics()

    # The transfer-ID timeout starts at the default. A non-positive value must be rejected and leave it unchanged.
    assert sis.transfer_id_timeout == approx(SerialInputSession.DEFAULT_TRANSFER_ID_TIMEOUT)
    sis.transfer_id_timeout = 1.0
    with raises(ValueError):
        sis.transfer_id_timeout = 0.0
    assert sis.transfer_id_timeout == approx(1.0)

    # Nothing received yet: both a short blocking wait and an already-expired deadline yield None.
    assert run_until_complete(sis.receive_until(get_monotonic() + 0.1)) is None
    assert run_until_complete(sis.receive_until(0.0)) is None

    # Builds a frame that matches the session's data specifier and data type hash.
    def mk_frame(transfer_id:       int,
                 index:             int,
                 end_of_transfer:   bool,
                 payload:           typing.Union[bytes, memoryview],
                 source_node_id:    typing.Optional[int]) -> SerialFrame:
        return SerialFrame(timestamp=ts,
                           priority=prio,
                           transfer_id=transfer_id,
                           index=index,
                           end_of_transfer=end_of_transfer,
                           payload=memoryview(payload),
                           source_node_id=source_node_id,
                           destination_node_id=dst_nid,
                           data_specifier=session_spec.data_specifier,
                           data_type_hash=payload_meta.data_type_hash)

    # ANONYMOUS TRANSFERS.
    # Only the frame with index 0 and the end-of-transfer flag set yields an anonymous transfer.
    # The first two frames below are expected to count as errors without producing a transfer.
    sis._process_frame(mk_frame(transfer_id=0,
                                index=0,
                                end_of_transfer=False,
                                payload=nihil_supernum,
                                source_node_id=None))
    assert sis.sample_statistics() == SerialInputSessionStatistics(
        frames=1,
        errors=1,
    )

    sis._process_frame(mk_frame(transfer_id=0,
                                index=1,
                                end_of_transfer=True,
                                payload=nihil_supernum,
                                source_node_id=None))
    assert sis.sample_statistics() == SerialInputSessionStatistics(
        frames=2,
        errors=2,
    )

    sis._process_frame(mk_frame(transfer_id=0,
                                index=0,
                                end_of_transfer=True,
                                payload=nihil_supernum,
                                source_node_id=None))
    assert sis.sample_statistics() == SerialInputSessionStatistics(
        transfers=1,
        frames=3,
        payload_bytes=len(nihil_supernum),
        errors=2,
    )
    # The accepted anonymous transfer is readable once; afterwards the queue is empty again.
    assert run_until_complete(sis.receive_until(0)) == \
        TransferFrom(timestamp=ts,
                     priority=prio,
                     transfer_id=0,
                     fragmented_payload=[memoryview(nihil_supernum)],
                     source_node_id=None)
    assert run_until_complete(sis.receive_until(get_monotonic() + 0.1)) is None
    assert run_until_complete(sis.receive_until(0.0)) is None

    # BAD DATA TYPE HASH.
    # The frame is rejected before reassembly: it counts as an error and its hash is recorded in
    # mismatched_data_type_hashes. No transfer is produced.
    sis._process_frame(
        SerialFrame(timestamp=ts,
                    priority=prio,
                    transfer_id=0,
                    index=0,
                    end_of_transfer=True,
                    payload=memoryview(nihil_supernum),
                    source_node_id=None,
                    destination_node_id=None,
                    data_specifier=session_spec.data_specifier,
                    data_type_hash=0xbad_bad_bad_bad_bad)
    )
    assert sis.sample_statistics() == SerialInputSessionStatistics(
        transfers=1,
        frames=4,
        payload_bytes=len(nihil_supernum),
        errors=3,
        mismatched_data_type_hashes={0xbad_bad_bad_bad_bad: 1},
    )

    # VALID TRANSFERS. Notice that they are unordered on purpose. The reassembler can deal with that.
    # Node 1111 sends a transfer of three payload frames (indexes 0-2) plus a final frame (index 3)
    # that carries only the transfer CRC; the CRC bytes are not counted in payload_bytes.
    # Node 2222 sends a single-frame transfer.
    sis._process_frame(mk_frame(transfer_id=0,
                                index=1,
                                end_of_transfer=False,
                                payload=nihil_supernum,
                                source_node_id=1111))

    sis._process_frame(mk_frame(transfer_id=0,
                                index=0,
                                end_of_transfer=True,
                                payload=nihil_supernum,
                                source_node_id=2222))       # COMPLETED FIRST

    # Each node that has sent a frame now has an (empty) entry in reassembly_errors_per_source_node_id.
    assert sis.sample_statistics() == SerialInputSessionStatistics(
        transfers=2,
        frames=6,
        payload_bytes=len(nihil_supernum) * 2,
        errors=3,
        mismatched_data_type_hashes={0xbad_bad_bad_bad_bad: 1},
        reassembly_errors_per_source_node_id={
            1111: {},
            2222: {},
        },
    )

    sis._process_frame(mk_frame(transfer_id=0,
                                index=3,
                                end_of_transfer=True,
                                payload=TransferCRC.new(nihil_supernum * 3).value_as_bytes,
                                source_node_id=1111))

    sis._process_frame(mk_frame(transfer_id=0,
                                index=0,
                                end_of_transfer=False,
                                payload=nihil_supernum,
                                source_node_id=1111))

    sis._process_frame(mk_frame(transfer_id=0,
                                index=2,
                                end_of_transfer=False,
                                payload=nihil_supernum,
                                source_node_id=1111))       # COMPLETED SECOND

    assert sis.sample_statistics() == SerialInputSessionStatistics(
        transfers=3,
        frames=9,
        payload_bytes=len(nihil_supernum) * 5,
        errors=3,
        mismatched_data_type_hashes={0xbad_bad_bad_bad_bad: 1},
        reassembly_errors_per_source_node_id={
            1111: {},
            2222: {},
        },
    )

    # Transfers are received in completion order: node 2222 first, then the 3-fragment transfer from 1111.
    assert run_until_complete(sis.receive_until(0)) == \
        TransferFrom(timestamp=ts,
                     priority=prio,
                     transfer_id=0,
                     fragmented_payload=[memoryview(nihil_supernum)],
                     source_node_id=2222)
    assert run_until_complete(sis.receive_until(0)) == \
        TransferFrom(timestamp=ts,
                     priority=prio,
                     transfer_id=0,
                     fragmented_payload=[memoryview(nihil_supernum)] * 3,
                     source_node_id=1111)
    assert run_until_complete(sis.receive_until(get_monotonic() + 0.1)) is None
    assert run_until_complete(sis.receive_until(0.0)) is None

    # TRANSFERS WITH REASSEMBLY ERRORS.
    # A non-final frame with an empty payload is reported as MULTIFRAME_EMPTY_FRAME. Each occurrence
    # increments both the total error count and the per-node error counter of node 1111.
    sis._process_frame(mk_frame(transfer_id=1,          # EMPTY IN MULTIFRAME
                                index=0,
                                end_of_transfer=False,
                                payload=b'',
                                source_node_id=1111))

    sis._process_frame(mk_frame(transfer_id=2,          # EMPTY IN MULTIFRAME
                                index=0,
                                end_of_transfer=False,
                                payload=b'',
                                source_node_id=1111))

    assert sis.sample_statistics() == SerialInputSessionStatistics(
        transfers=3,
        frames=11,
        payload_bytes=len(nihil_supernum) * 5,
        errors=5,
        mismatched_data_type_hashes={0xbad_bad_bad_bad_bad: 1},
        reassembly_errors_per_source_node_id={
            1111: {
                TransferReassembler.Error.MULTIFRAME_EMPTY_FRAME: 2,
            },
            2222: {},
        },
    )

    # close() must invoke the finalizer; calling it a second time must be harmless.
    assert not finalized
    sis.close()
    assert finalized
    sis.close()     # Idempotency check
